{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
                        "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
                        "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
                    ]
                }
            ],
            "source": [
                "from transformers import AutoTokenizer,AutoModel\n",
                "import torch\n",
                "import torch.cuda as cuda\n",
                "#加载分词器\n",
                "tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')\n",
                "#加载预训练模型\n",
                "pretrained = AutoModel.from_pretrained('hfl/chinese-roberta-wwm-ext')\n",
                "\n",
                "device = torch.device('cuda' if cuda.is_available() else 'cpu')\n",
                "# pretrained = pretrained.to(device)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "metadata": {},
            "outputs": [],
            "source": [
                "# 定义下游模型\n",
                "class Model(torch.nn.Module):\n",
                "    def __init__(self):\n",
                "        super().__init__()\n",
                "        self.tuning = False\n",
                "        self.pretrained = None\n",
                "\n",
                "        self.rnn = torch.nn.GRU(768, 768, batch_first=True)\n",
                "        self.fc = torch.nn.Linear(768, 8)\n",
                "\n",
                "    def forward(self, inputs):\n",
                "        if self.tuning:\n",
                "            out = self.pretrained(**inputs).last_hidden_state\n",
                "        else:\n",
                "            with torch.no_grad():\n",
                "                out = pretrained(**inputs).last_hidden_state\n",
                "\n",
                "        out, _ = self.rnn(out)\n",
                "\n",
                "        out = self.fc(out).softmax(dim=2)\n",
                "\n",
                "        return out\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "metadata": {},
            "outputs": [],
            "source": [
                "model_load = torch.load('ner5.model')\n",
                "model_load.eval()    \n",
                "model_load.to('cpu')\n",
                "\n",
                "def custom_predict(sentence):\n",
                "\n",
                "    inputs = tokenizer.encode_plus([sentence],\n",
                "                                   truncation=True,\n",
                "                                   padding=True,\n",
                "                                   return_tensors='pt',\n",
                "                                   is_split_into_words=True)\n",
                "\n",
                "    with torch.no_grad():\n",
                "        outputs = model_load(inputs)\n",
                "\n",
                "    # [b, lens] -> [b, lens]\n",
                "    preds = outputs.argmax(dim=2)[0]\n",
                "\n",
                "    result = ''\n",
                "\n",
                "    res = [set() for _ in range(3)]\n",
                "    tmp = ''\n",
                "    current_flag = -1\n",
                "\n",
                "    for i in range(len(preds)):\n",
                "        if inputs['attention_mask'][0][i] == 1:\n",
                "            result += tokenizer.decode(inputs['input_ids'][0][i])+' '\n",
                "            result += str(preds[i].item())+' '\n",
                "            num = preds[i].item()\n",
                "            # num 不为0和7和#表示该词为关键词\n",
                "            if (num != 0 and num != 7 and num != '#'):\n",
                "                # 关键词开始为奇数\n",
                "                if (num & 1):\n",
                "                    # 将形如 广5东6广5州6 拆成两个词\n",
                "                    if (len(tmp) > 1):\n",
                "                        res[(current_flag-1)//2].add(tmp)\n",
                "                        tmp = ''\n",
                "                    current_flag = num\n",
                "                    tmp += tokenizer.decode(inputs['input_ids'][0][i])\n",
                "\n",
                "                    # 防止形如 X4X4 出现\n",
                "                elif (num & 1 == 0 and current_flag != -1):\n",
                "                    tmp += tokenizer.decode(inputs['input_ids'][0][i])\n",
                "            else:\n",
                "                if (len(tmp) > 1):\n",
                "                    # current_flag 1对应姓名下标0，3对应组织下表1，5对应地点下表2\n",
                "                    res[(current_flag-1)//2].add(tmp)\n",
                "                    tmp = ''\n",
                "                    current_flag = -1\n",
                "\n",
                "    return res"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 9,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "姓名[],组织['广州二中高中'],地点[]\n"
                    ]
                }
            ],
            "source": [
                "input_sentence = \"2016.9-2019.7          广州二中           高中\"\n",
                "output_prediction = custom_predict(input_sentence)\n",
                "print(\n",
                "    f'姓名{list(output_prediction[0])},组织{list(output_prediction[1])},地点{list(output_prediction[2])}')\n"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.0"
        },
        "orig_nbformat": 4
    },
    "nbformat": 4,
    "nbformat_minor": 2
}